{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import category_encoders as ce\n",
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_covtype\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from pytorch_tabular.utils import print_metrics, load_covertype_dataset\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def load_classification_data():\n",
    "    data, cat_col_names, num_col_names, target = load_covertype_dataset()\n",
    "    test_idx = data.sample(int(0.2 * len(data)), random_state=42).index\n",
    "    test = data[data.index.isin(test_idx)]\n",
    "    train = data[~data.index.isin(test_idx)]\n",
    "    return (train, test, cat_col_names, num_col_names, target)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Load Forest Cover Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "train, test, cat_col_names, num_col_names, target_col = load_classification_data()\n",
    "train, val = train_test_split(train, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "encoder = ce.OneHotEncoder(cols=cat_col_names)\n",
    "train_transform = encoder.fit_transform(train)\n",
    "val_transform = encoder.transform(val)\n",
    "test_transform = encoder.transform(test)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Baseline\n",
    "\n",
    "Let's use the default LightGBM model as a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "metrics = [\n",
    "    (accuracy_score, \"Accuracy\", {}),\n",
    "    (f1_score, \"F1\", {\"average\": \"weighted\"}),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 0.8561396865829626 | Validation F1: 0.8555362266480759\n",
      "Holdout Accuracy: 0.8555876835166348 | Holdout F1: 0.8548755340164053\n"
     ]
    }
   ],
   "source": [
    "clf = lgb.LGBMClassifier(random_state=42, n_jobs=-1, verbose=-1)\n",
    "clf.fit(\n",
    "    train_transform.drop(columns=target_col),\n",
    "    train_transform[target_col].values.ravel(),\n",
    ")\n",
    "val_pred = clf.predict(val_transform.drop(columns=target_col))\n",
    "val_metrics = print_metrics(\n",
    "    metrics, val_transform[target_col], val_pred, \"Validation\", return_dict=True\n",
    ")\n",
    "test_pred = clf.predict(test_transform.drop(columns=target_col))\n",
    "holdout_metrics = print_metrics(\n",
    "    metrics, test_transform[target_col], test_pred, \"Holdout\", return_dict=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "results.append(\n",
    "    {\n",
    "        \"Mode\": \"OneHot Encoding\",\n",
    "        \"Validation Acc\": val_metrics[\"Accuracy\"],\n",
    "        \"Validation F1\": val_metrics[\"F1\"],\n",
    "        \"Holdout Acc\": holdout_metrics[\"Accuracy\"],\n",
    "        \"Holdout F1\": holdout_metrics[\"F1\"],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## CategoryEmbedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[target_col],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    "    continuous_feature_transform=\"quantile_normal\",\n",
    "    normalize_continuous_features=True,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    auto_lr_find=True,  # Runs the LRFinder to automatically derive a learning rate\n",
    "    batch_size=1024,\n",
    "    max_epochs=50,\n",
    "    accelerator=\"auto\",  # can be 'cpu','gpu', 'tpu', or 'ipu'\n",
    "    devices=-1,  # -1 means use all available\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\",  # No additional layer in head, just a mapping layer to output_dim\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"512-256-16\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\",  # Activation between each layers\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\",\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    "    learning_rate=1e-3,\n",
    "    metrics=[\"accuracy\", \"f1_score\"],\n",
    "    metrics_params=[{}, {\"average\": \"micro\"}],\n",
    "    metrics_prob_input=[False, True],\n",
    ")\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false",
    "scrolled": false,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecadcc00eea9411b8fb57cc1b57594c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Trainer was signaled to stop but the required `min_epochs=1` or `min_steps=None` has not been met. Training will continue...\n",
      "`Trainer.fit` stopped: `max_steps=100` reached.\n",
      "Learning rate set to 0.001584893192461114\n",
      "Restoring states from the checkpoint path at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_a39eae5d-c288-4a00-ae11-afde402b0d6a.ckpt\n",
      "Restored all states from the checkpoint at /home/manujosephv/pytorch_tabular/docs/tutorials/.lr_find_a39eae5d-c288-4a00-ae11-afde402b0d6a.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  153 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │    119 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  153 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │    119 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 154 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 154 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 154 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 154 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab0bba3fec2e40a5bc07fef32597d0f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=50` reached.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f251411f590>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tabular_model.fit(train=train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41eed282f29848289fef73e134c1e73a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9165160655975342     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_f1_score       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9165160655975342     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.19816306233406067    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9165160655975342    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_f1_score      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9165160655975342    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.19816306233406067   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'test_loss': 0.19816306233406067, 'test_accuracy': 0.9165160655975342, 'test_f1_score': 0.9165160655975342}]\n"
     ]
    }
   ],
   "source": [
    "result = tabular_model.evaluate(test)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "pred_df = tabular_model.predict(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Holdout Accuracy: 0.9165160668491076 | Holdout F1: 0.9163006785311444\n"
     ]
    }
   ],
   "source": [
    "print_metrics(metrics, test[target_col], pred_df[\"prediction\"], tag=\"Holdout\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Extract the Learned Embedding\n",
    "\n",
    "For the models that support (CategoryEmbeddingModel and CategoryEmbeddingNODE), we can extract the learned embeddings into a sci-kit learn style Transformer. You can use this in your Sci-kit Learn pipelines and workflows as a drop in replacement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afcf16a800584261a8ac69f1c55f6bcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/manujosephv/pytorch_tabular/src/pytorch_tabular/categorical_encoders.py:188: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n",
      "  embedding.weight[self._categorical_encoder._mapping[col].loc[key], :]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LGBMClassifier(random_state=42, verbose=-1)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LGBMClassifier</label><div class=\"sk-toggleable__content\"><pre>LGBMClassifier(random_state=42, verbose=-1)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "LGBMClassifier(random_state=42, verbose=-1)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformer = CategoricalEmbeddingTransformer(tabular_model)\n",
    "train_transform = transformer.fit_transform(train)\n",
    "clf = lgb.LGBMClassifier(random_state=42, verbose=-1)\n",
    "clf.fit(train_transform.drop(columns=target_col), train_transform[target_col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62ce517f1b1b4c8f85f97de29c563d16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b1112896ff004574bbacfb617318dead",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation Accuracy: 0.8600552481433353 | Validation F1: 0.8595332390375023\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Holdout Accuracy: 0.8595721244040551 | Holdout F1: 0.8590314305879913\n"
     ]
    }
   ],
   "source": [
    "val_transform = transformer.transform(val)\n",
    "val_pred = clf.predict(val_transform.drop(columns=target_col))\n",
    "val_metrics = print_metrics(\n",
    "    metrics, val_transform[target_col], val_pred, \"Validation\", return_dict=True\n",
    ")\n",
    "test_transform = transformer.transform(test)\n",
    "test_pred = clf.predict(test_transform.drop(columns=target_col))\n",
    "holdout_metrics = print_metrics(\n",
    "    metrics, test_transform[target_col], test_pred, \"Holdout\", return_dict=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "results.append(\n",
    "    {\n",
    "        \"Mode\": \"NeuralEmbedding\",\n",
    "        \"Validation Acc\": val_metrics[\"Accuracy\"],\n",
    "        \"Validation F1\": val_metrics[\"F1\"],\n",
    "        \"Holdout Acc\": holdout_metrics[\"Accuracy\"],\n",
    "        \"Holdout F1\": holdout_metrics[\"F1\"],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Mode</th>\n",
       "      <th>OneHot Encoding</th>\n",
       "      <th>NeuralEmbedding</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Validation Acc</th>\n",
       "      <td>0.856140</td>\n",
       "      <td>0.860055</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Validation F1</th>\n",
       "      <td>0.855536</td>\n",
       "      <td>0.859533</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Holdout Acc</th>\n",
       "      <td>0.855588</td>\n",
       "      <td>0.859572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Holdout F1</th>\n",
       "      <td>0.854876</td>\n",
       "      <td>0.859031</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Mode            OneHot Encoding  NeuralEmbedding\n",
       "Validation Acc         0.856140         0.860055\n",
       "Validation F1          0.855536         0.859533\n",
       "Holdout Acc            0.855588         0.859572\n",
       "Holdout F1             0.854876         0.859031"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df = pd.DataFrame(results).T\n",
    "res_df.columns = res_df.iloc[0]\n",
    "res_df = res_df.iloc[1:].astype(float)\n",
    "res_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabular_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
